{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "This example demonstrates the workflow to download a publicly available TF \n",
    "model, strip part of it for inference, and convert it to CoreML using the \n",
    "tfcoreml converter. \n",
    "\n",
    "Stripping part of the TF model may be useful when:\n",
    "(1) the TF model contains input data pre-processing mechanisms that are \n",
    "suitable for training / unsupported by CoreML\n",
    "(2) the TF model has ops only used in training time\n",
    "\n",
    "We use an inception v3 model provided by Google, which can be downloaded \n",
    "at this URL:\n",
    "\n",
    "https://storage.googleapis.com/download.tensorflow.org/models/inception_dec_2015.zip\n",
    "\"\"\"\n",
    "from __future__ import print_function\n",
    "import  os, sys, zipfile\n",
    "from os.path import dirname\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tensorflow.core.framework import graph_pb2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download the model and class label package\n",
    "def download_file_and_unzip(url, dir_path='.'):\n",
    "    \"\"\"Download a zip archive (unless already cached) and extract it.\n",
    "\n",
    "    url - The URL address of the zipped frozen model\n",
    "    dir_path - local directory that receives the archive and its\n",
    "               extracted contents; created if it does not exist\n",
    "    \"\"\"\n",
    "    if not os.path.exists(dir_path):\n",
    "        os.makedirs(dir_path)\n",
    "    # The archive keeps the name of the last URL path component\n",
    "    fname = url[url.rfind('/') + 1:]\n",
    "    fpath = os.path.join(dir_path, fname)\n",
    "\n",
    "    if not os.path.exists(fpath):\n",
    "        if sys.version_info[0] < 3:\n",
    "            from urllib import urlretrieve\n",
    "        else:\n",
    "            from urllib.request import urlretrieve\n",
    "        # Download under a temporary name first so an interrupted\n",
    "        # transfer is never mistaken for a complete, cached archive\n",
    "        partial_path = fpath + '.part'\n",
    "        urlretrieve(url, partial_path)\n",
    "        os.rename(partial_path, fpath)\n",
    "\n",
    "    # The context manager closes the archive even if extraction fails\n",
    "    with zipfile.ZipFile(fpath, 'r') as zip_ref:\n",
    "        zip_ref.extractall(dir_path)\n",
    "\n",
    "inception_v3_url = 'https://storage.googleapis.com/download.tensorflow.org/models/inception_dec_2015.zip'\n",
    "download_file_and_unzip(inception_v3_url)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the TF graph definition\n",
    "tf_model_path = './tensorflow_inception_graph.pb'\n",
    "with open(tf_model_path, 'rb') as f:\n",
    "    serialized = f.read()\n",
    "tf.reset_default_graph()\n",
    "original_gdef = tf.GraphDef()\n",
    "original_gdef.ParseFromString(serialized)\n",
    "\n",
    "# For demonstration purposes we show the first 15 ops of the TF model\n",
    "with tf.Graph().as_default() as g:\n",
    "    tf.import_graph_def(original_gdef, name='')\n",
    "    ops = g.get_operations()\n",
    "    # Slice rather than index so a graph with fewer than 15 ops\n",
    "    # does not raise an IndexError\n",
    "    for i, op in enumerate(ops[:15]):\n",
    "        print('op id {} : op name: {}, op type: \"{}\"'.format(i, op.name, op.type))\n",
    "\n",
    "# This Inception model uses DecodeJpeg op to read from JPEG images\n",
    "# encoded as string Tensors. You can visualize it with TensorBoard,\n",
    "# but we're omitting it here. For deployment we need to remove the\n",
    "# JPEG decoder and related ops, and replace them with a placeholder\n",
    "# where we can feed image data in. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Strip the JPEG decoder and preprocessing part of the TF model.\n",
    "# In this model, the actual op that feeds the pre-processed image into\n",
    "# the network is 'Mul'. The op that generates probabilities per\n",
    "# class is 'softmax/logits'.\n",
    "# To figure out the inputs/outputs of your own model, you can use\n",
    "# TensorFlow's summarize_graph or the TensorBoard visualization tool.\n",
    "\n",
    "from tensorflow.python.tools import strip_unused_lib\n",
    "from tensorflow.python.framework import dtypes\n",
    "from tensorflow.python.platform import gfile\n",
    "input_node_names = ['Mul']\n",
    "output_node_names = ['softmax/logits']\n",
    "# Keep only the sub-graph between the named input and output nodes;\n",
    "# the stripped input is replaced by a float32 placeholder\n",
    "gdef = strip_unused_lib.strip_unused(\n",
    "    input_graph_def=original_gdef,\n",
    "    input_node_names=input_node_names,\n",
    "    output_node_names=output_node_names,\n",
    "    placeholder_type_enum=dtypes.float32.as_datatype_enum)\n",
    "# Write the stripped graph to an output file\n",
    "frozen_model_file = './inception_v3.pb'\n",
    "with gfile.GFile(frozen_model_file, \"wb\") as f:\n",
    "    f.write(gdef.SerializeToString())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Now we have a TF model ready to be converted to CoreML\n",
    "import tfcoreml\n",
    "# Map each input tensor's name to its shape (including the\n",
    "# batch axis)\n",
    "input_tensor_shapes = {\"Mul:0\": [1, 299, 299, 3]}  # batch size is 1\n",
    "# Output CoreML model path\n",
    "coreml_model_file = './inception_v3.mlmodel'\n",
    "# The TF model's output tensor name\n",
    "output_tensor_names = ['softmax/logits:0']\n",
    "\n",
    "# Call the converter. This may take a while\n",
    "coreml_model = tfcoreml.convert(\n",
    "    tf_model_path=frozen_model_file,\n",
    "    mlmodel_path=coreml_model_file,\n",
    "    input_name_shape_dict=input_tensor_shapes,\n",
    "    output_feature_names=output_tensor_names,\n",
    "    image_input_names=['Mul:0'],\n",
    "    red_bias=-1,\n",
    "    green_bias=-1,\n",
    "    blue_bias=-1,\n",
    "    image_scale=2.0/255.0)\n",
    "\n",
    "# MLModel saved at location: ./inception_v3.mlmodel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Now we're ready to test out the CoreML model with a real image!\n",
    "# Load an image\n",
    "# Import the submodule explicitly: a bare `import PIL` does not load\n",
    "# PIL.Image, so PIL.Image.open would depend on another package\n",
    "# having imported it first\n",
    "import PIL.Image\n",
    "import requests\n",
    "from io import BytesIO\n",
    "from matplotlib.pyplot import imshow\n",
    "# This is an image of a golden retriever from Wikipedia\n",
    "img_url = 'https://upload.wikimedia.org/wikipedia/commons/9/93/Golden_Retriever_Carlos_%2810581910556%29.jpg'\n",
    "response = requests.get(img_url)\n",
    "# Fail here with a clear HTTP error instead of a confusing image\n",
    "# decoding error when the server rejects the request\n",
    "response.raise_for_status()\n",
    "%matplotlib inline\n",
    "img = PIL.Image.open(BytesIO(response.content))\n",
    "imshow(np.asarray(img))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run CoreML prediction\n",
    "# Pay attention to '__0'. We change ':0' to '__0' to make sure \n",
    "# MLModel's generated Swift/Obj-C code is semantically correct\n",
    "# LANCZOS is the filter that the ANTIALIAS name aliased; ANTIALIAS is\n",
    "# deprecated and no longer available in newer Pillow releases\n",
    "img = img.resize([299,299], PIL.Image.LANCZOS)\n",
    "coreml_inputs = {'Mul__0': img}\n",
    "coreml_output = coreml_model.predict(coreml_inputs, useCPUOnly=False)\n",
    "probs = coreml_output['softmax__logits__0'].flatten()\n",
    "label_idx = np.argmax(probs)\n",
    "\n",
    "# This label file comes with the model\n",
    "label_file = 'imagenet_comp_graph_label_strings.txt' \n",
    "with open(label_file) as f:\n",
    "    # splitlines() drops the line terminators that readlines() keeps,\n",
    "    # so the printed label has no stray trailing newline\n",
    "    labels = f.read().splitlines()\n",
    "print('Label = {}'.format(labels[label_idx]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# And that's the end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "celltoolbar": "Raw Cell Format",
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
